//MatrixMulFma.cs

using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics.X86;
using static Util;
using size_t = nuint;
using m256d = System.Runtime.Intrinsics.Vector256<double>;

// FMA-based double-precision matrix-multiplication backend.
// Complete 6x8 tiles are computed either by pre-assembled native kernels
// (x64 Windows / Linux) or by the managed intrinsics kernel below; block
// packing and ragged-edge tiles are delegated to the AVX implementation.
unsafe struct MatrixMulFma : IMatrixMulSimdSpecific
{
	// Cache-blocking geometry: packed blocks are BlockRows x BlockCols,
	// consumed by a 6-row x 8-column register kernel.
	public size_t BlockRows => 228;
	public size_t BlockCols => 128;
	public size_t KernelRows => 6;
	public size_t KernelCols => 8;

	// AVX implementation reused for the partial (edge) kernel.
	MatrixMulAvx mmAvx;

	// Packs a src_rows x src_cols block into the 8-column layout the kernel reads.
	public void PackBlock(double* dst, double* src, size_t src_rows, size_t src_cols, size_t src_stride)
	{
		MatrixMulAvx.PackBlockBy8Cols(dst, src, src_rows, src_cols, src_stride);
	}

	// Packs a vertical panel into the 6-row layout the kernel reads.
	public void PackVPanel(double* dst, double* src, size_t src_rows, size_t src_cols, size_t src_stride)
	{
		MatrixMulAvx.PackVPanelBy6Rows(dst, src, src_rows, src_cols, src_stride);
	}

	// Selected 6x8 kernel: native machine code on x64 Windows/Linux, otherwise
	// the managed fallback. Lazily initialized on first use.
	delegate*<bool, double*, double*, double*, size_t, size_t, void> kernel6x8;

	// Computes (or, when add == true, accumulates into) a full 6x8 tile of C
	// from a packed A panel (6 doubles per k-step) and packed B block
	// (8 doubles per k-step).
	[MethodImpl(MethodImplOptions.AggressiveInlining)]
	public void MulKernelComplete(bool add, double* C, double* A, double* B, size_t A_cols, size_t C_stride)
	{
		if (kernel6x8 == null)
			InitKernel();	// benign race: concurrent callers store the same pointer
		kernel6x8(add, C, A, B, A_cols, C_stride);
	}

	// Ragged tile (fewer than 6 rows and/or 8 columns): handled by the AVX path.
	[MethodImpl(MethodImplOptions.AggressiveInlining)]
	public void MulKernelIncomplete(bool add, double* C, double* A, double* B, size_t C_rows, size_t C_cols, size_t A_cols, size_t C_stride)
	{
		mmAvx.MulKernelIncomplete(add, C, A, B, C_rows, C_cols, A_cols, C_stride);
	}

	// Picks the best available kernel for the current OS/architecture.
	// Deliberately NoInlining to keep the cold path out of MulKernelComplete.
	[MethodImpl(MethodImplOptions.NoInlining)]
	void InitKernel()
	{
		if (IsX64 && IsWindows)
			kernel6x8 = (delegate*<bool, double*, double*, double*, size_t, size_t, void>)FmaKernel6x8Win64.kernel;
		else if (IsX64 && IsLinux)
			kernel6x8 = (delegate*<bool, double*, double*, double*, size_t, size_t, void>)FmaKernel6x8Linux64.kernel;
		else
			kernel6x8 = &MulKernel6x8Managed;
	}

	// Managed fallback 6x8 kernel: C[r][c] (+)= sum_k A[k][r] * B[k][c] for
	// r in 0..5, c in 0..7, with A packed 6-per-step and B packed 8-per-step.
	[MethodImpl(MethodImplOptions.AggressiveOptimization)]
	static void MulKernel6x8Managed(bool add, double* C, double* A, double* B, size_t A_cols, size_t C_stride)
	{
		// Twelve accumulators: two 4-wide vectors (8 doubles) per output row.
		m256d sum00, sum01, sum10, sum11,
			  sum20, sum21, sum30, sum31,
			  sum40, sum41, sum50, sum51;

		sum00 = sum01 = sum10 = sum11 = sum20 = sum21
			  = sum30 = sum31 = sum40 = sum41 = sum50 = sum51
			  = m256d.Zero;

		double* a0 = A, b0 = B;
		for (size_t k = 0; k < A_cols; ++k) {
			// Prefetch 768 bytes (12 cache lines) ahead, matching the distance
			// used by the pre-assembled kernels. Cast through byte* — NOT
			// char* as in C: sizeof(char) == 2 in C#, so a char* cast would
			// double the effective offset to 1536 bytes.
			const size_t prefetch_offset = 64*12;
			Sse.Prefetch0((byte*)b0 + prefetch_offset);
			Sse.Prefetch0((byte*)a0 + prefetch_offset);

			// Rows 0-2 reload B directly per FMA; rows 3-5 reuse vb0/vb1.
			// Presumably a deliberate register-pressure / scheduling choice —
			// kept as-is.
			m256d va, vb0, vb1;
			va = Avx.BroadcastScalarToVector256(a0 + 0);
			sum00 = Fma.MultiplyAdd(Avx.LoadVector256(b0 + 0), va, sum00);
			sum01 = Fma.MultiplyAdd(Avx.LoadVector256(b0 + 4), va, sum01);

			va = Avx.BroadcastScalarToVector256(a0 + 1);
			sum10 = Fma.MultiplyAdd(Avx.LoadVector256(b0 + 0), va, sum10);
			sum11 = Fma.MultiplyAdd(Avx.LoadVector256(b0 + 4), va, sum11);

			va = Avx.BroadcastScalarToVector256(a0 + 2);
			sum20 = Fma.MultiplyAdd(Avx.LoadVector256(b0 + 0), va, sum20);
			sum21 = Fma.MultiplyAdd(Avx.LoadVector256(b0 + 4), va, sum21);

			vb0 = Avx.LoadVector256(b0);
			vb1 = Avx.LoadVector256(b0 + 4);

			va = Avx.BroadcastScalarToVector256(a0 + 3);
			sum30 = Fma.MultiplyAdd(vb0, va, sum30);
			sum31 = Fma.MultiplyAdd(vb1, va, sum31);

			va = Avx.BroadcastScalarToVector256(a0 + 4);
			sum40 = Fma.MultiplyAdd(vb0, va, sum40);
			sum41 = Fma.MultiplyAdd(vb1, va, sum41);

			va = Avx.BroadcastScalarToVector256(a0 + 5);
			sum50 = Fma.MultiplyAdd(vb0, va, sum50);
			sum51 = Fma.MultiplyAdd(vb1, va, sum51);

			a0 += 6;	// next packed A column group
			b0 += 8;	// next packed B row
		}

		// Row pointers into the (strided, unpacked) output matrix C.
		double* c0, c1, c2, c3, c4, c5;
		c0 = C + C_stride * 0;
		c1 = C + C_stride * 1;
		c2 = C + C_stride * 2;
		c3 = C + C_stride * 3;
		c4 = C + C_stride * 4;
		c5 = C + C_stride * 5;

		// add == true: accumulate into the existing contents of C.
		if (add) {
			sum00 = Avx.Add(sum00, Avx.LoadVector256(c0 + 0));
			sum01 = Avx.Add(sum01, Avx.LoadVector256(c0 + 4));
			sum10 = Avx.Add(sum10, Avx.LoadVector256(c1 + 0));
			sum11 = Avx.Add(sum11, Avx.LoadVector256(c1 + 4));
			sum20 = Avx.Add(sum20, Avx.LoadVector256(c2 + 0));
			sum21 = Avx.Add(sum21, Avx.LoadVector256(c2 + 4));
			sum30 = Avx.Add(sum30, Avx.LoadVector256(c3 + 0));
			sum31 = Avx.Add(sum31, Avx.LoadVector256(c3 + 4));
			sum40 = Avx.Add(sum40, Avx.LoadVector256(c4 + 0));
			sum41 = Avx.Add(sum41, Avx.LoadVector256(c4 + 4));
			sum50 = Avx.Add(sum50, Avx.LoadVector256(c5 + 0));
			sum51 = Avx.Add(sum51, Avx.LoadVector256(c5 + 4));
		}
		Avx.Store(c0 + 0, sum00);
		Avx.Store(c0 + 4, sum01);
		Avx.Store(c1 + 0, sum10);
		Avx.Store(c1 + 4, sum11);
		Avx.Store(c2 + 0, sum20);
		Avx.Store(c2 + 4, sum21);
		Avx.Store(c3 + 0, sum30);
		Avx.Store(c3 + 4, sum31);
		Avx.Store(c4 + 0, sum40);
		Avx.Store(c4 + 4, sum41);
		Avx.Store(c5 + 0, sum50);
		Avx.Store(c5 + 4, sum51);
	}

	public void Dispose()
	{
		mmAvx.Dispose();
	}
}


// Pre-assembled FMA 6x8 micro-kernel for the Windows x64 calling convention.
// The machine code below is decoded from hex and copied into executable pages
// at class-load time; MatrixMulFma.InitKernel casts `kernel` to
// delegate*<bool, double*, double*, double*, size_t, size_t, void>
// (add, C, A, B, A_cols, C_stride) before calling it.
unsafe static class FmaKernel6x8Win64
{
	// Hex dump of the kernel's machine code, 32 bytes per line
	// (trailing 0x90 NOPs pad the final line).
	static string[] bytes = {
		"48 8B C4 48 81 EC 98 00 00 00 C5 F8 29 70 E8 4C 8B D2 48 8B 94 24 C0 00 00 00 4D 8B D8 C5 F8 29",
		"78 D8 45 33 C0 C5 78 29 40 C8 C5 78 29 48 B8 C5 78 29 50 A8 C5 78 29 58 98 C5 78 29 60 88 C5 78",
		"29 6C 24 10 C5 78 29 34 24 48 89 78 F8 0F B6 F9 48 8B C2 C5 E1 57 DB C5 D9 57 E4 C5 D1 57 ED C5",
		"C9 57 F6 C5 C1 57 FF C4 41 39 57 C0 C4 41 31 57 C9 C4 41 29 57 D2 C4 41 21 57 DB C4 41 19 57 E4",
		"C4 41 11 57 ED C4 41 09 57 F6 48 83 E0 FC 0F 86 32 02 00 00 48 FF C8 48 89 9C 24 A0 00 00 00 48",
		"C1 E8 02 49 8D 4B 08 48 FF C0 49 8D 59 20 4C 8D 04 85 00 00 00 00 66 66 0F 1F 84 00 00 00 00 00",
		"0F 18 89 F8 02 00 00 0F 18 8B E0 02 00 00 C5 FD 10 4B E0 C5 FD 10 13 C4 E2 7D 19 41 F8 C4 E2 FD",
		"B8 D9 C4 E2 FD B8 E2 C4 E2 7D 19 01 C4 E2 FD B8 E9 C4 E2 FD B8 F2 C4 E2 7D 19 41 08 C4 E2 FD B8",
		"F9 C4 62 FD B8 C2 C4 E2 7D 19 41 10 C4 62 FD B8 C9 C4 62 FD B8 D2 C4 E2 7D 19 41 18 C4 62 FD B8",
		"D9 C4 62 FD B8 E2 C4 E2 7D 19 41 20 C4 62 FD B8 E9 C4 62 FD B8 F2 0F 18 89 38 03 00 00 0F 18 8B",
		"20 03 00 00 C5 FD 10 4B 20 C5 FD 10 53 40 C4 E2 7D 19 41 28 C4 E2 FD B8 D9 C4 E2 FD B8 E2 C4 E2",
		"7D 19 41 30 C4 E2 FD B8 E9 C4 E2 FD B8 F2 C4 E2 7D 19 41 38 C4 E2 FD B8 F9 C4 62 FD B8 C2 C4 E2",
		"7D 19 41 40 C4 62 FD B8 C9 C4 62 FD B8 D2 C4 E2 7D 19 41 48 C4 62 FD B8 D9 C4 62 FD B8 E2 C4 E2",
		"7D 19 41 50 C4 62 FD B8 E9 C4 62 FD B8 F2 0F 18 89 78 03 00 00 0F 18 8B 60 03 00 00 C5 FD 10 4B",
		"60 C5 FD 10 93 80 00 00 00 C4 E2 7D 19 41 58 C4 E2 FD B8 D9 C4 E2 FD B8 E2 C4 E2 7D 19 41 60 C4",
		"E2 FD B8 E9 C4 E2 FD B8 F2 C4 E2 7D 19 41 68 C4 E2 FD B8 F9 C4 62 FD B8 C2 C4 E2 7D 19 41 70 C4",
		"62 FD B8 C9 C4 62 FD B8 D2 C4 E2 7D 19 41 78 C4 62 FD B8 D9 C4 62 FD B8 E2 C4 E2 7D 19 81 80 00",
		"00 00 C4 62 FD B8 E9 C4 62 FD B8 F2 0F 18 8B A0 03 00 00 C5 FD 10 8B A0 00 00 00 C5 FD 10 93 C0",
		"00 00 00 C4 E2 7D 19 81 88 00 00 00 48 8D 89 C0 00 00 00 48 8D 9B 00 01 00 00 C4 E2 FD B8 D9 C4",
		"E2 FD B8 E2 C4 E2 7D 19 41 D0 C4 E2 FD B8 E9 C4 E2 FD B8 F2 C4 E2 7D 19 41 D8 C4 E2 FD B8 F9 C4",
		"62 FD B8 C2 C4 E2 7D 19 41 E0 C4 62 FD B8 C9 C4 62 FD B8 D2 C4 E2 7D 19 41 E8 C4 62 FD B8 D9 C4",
		"62 FD B8 E2 C4 E2 7D 19 41 F0 C4 62 FD B8 E9 C4 62 FD B8 F2 48 83 E8 01 0F 85 02 FE FF FF 48 8B",
		"9C 24 A0 00 00 00 4C 3B C2 0F 83 98 00 00 00 4B 8D 04 40 49 8B C8 48 C1 E0 04 48 83 C0 10 48 C1",
		"E1 06 49 03 C3 49 03 C9 49 2B D0 0F 1F 44 00 00 C5 FD 10 09 C5 FD 10 51 20 C4 E2 7D 19 40 F0 48",
		"8D 40 30 48 8D 49 40 C4 E2 FD B8 D9 C4 E2 FD B8 E2 C4 E2 7D 19 40 C8 C4 E2 FD B8 E9 C4 E2 FD B8",
		"F2 C4 E2 7D 19 40 D0 C4 E2 FD B8 F9 C4 62 FD B8 C2 C4 E2 7D 19 40 D8 C4 62 FD B8 C9 C4 62 FD B8",
		"D2 C4 E2 7D 19 40 E0 C4 62 FD B8 D9 C4 62 FD B8 E2 C4 E2 7D 19 40 E8 C4 62 FD B8 E9 C4 62 FD B8",
		"F2 48 83 EA 01 75 89 48 8B 8C 24 C8 00 00 00 48 8B D1 4C 8B C1 48 C1 E2 04 49 C1 E0 05 49 03 D2",
		"48 8D 04 49 4D 03 C2 40 84 FF 4D 8D 0C C2 48 8B BC 24 90 00 00 00 48 8D 04 89 74 44 C4 C1 65 58",
		"1A C4 C1 5D 58 62 20 C4 C1 55 58 2C CA C4 C1 4D 58 74 CA 20 C5 C5 58 3A C5 3D 58 42 20 C4 41 35",
		"58 09 C4 41 2D 58 51 20 C4 41 25 58 18 C4 41 1D 58 60 20 C4 41 15 58 2C C2 C4 41 0D 58 74 C2 20",
		"C4 C1 7D 11 1A C4 C1 7D 11 62 20 C4 C1 7D 11 2C CA C4 C1 7D 11 74 CA 20 C5 FD 11 3A C5 7D 11 42",
		"20 C4 41 7D 11 09 C4 41 7D 11 51 20 C4 41 7D 11 18 C4 41 7D 11 60 20 C4 41 7D 11 2C C2 C4 41 7D",
		"11 74 C2 20 C5 F8 77 C5 F8 28 7C 24 70 4C 8D 9C 24 98 00 00 00 C4 C1 78 28 73 E8 C4 41 78 28 43",
		"C8 C4 41 78 28 4B B8 C4 41 78 28 53 A8 C4 41 78 28 5B 98 C4 41 78 28 63 88 C5 78 28 6C 24 10 C5",
		"78 28 34 24 49 8B E3 C3 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90",
	};
	// Entry point of the copied machine code; valid after the static
	// constructor has run.
	public static void* kernel;

	static FmaKernel6x8Win64()
	{
		// Decode the hex dump and copy it into freshly allocated executable pages.
		// NOTE(review): pages appear to stay writable+executable for the process
		// lifetime and are never freed — fine for a singleton kernel, but worth
		// confirming against AllocExecPages' protection flags.
		byte[] codeBytes = ConvertBytes(bytes);
		kernel = AllocExecPages((size_t)codeBytes.Length);
		Marshal.Copy(codeBytes, 0, (IntPtr)kernel, codeBytes.Length);
	}
}

// Pre-assembled FMA 6x8 micro-kernel for the System V (Linux) x64 calling
// convention. The machine code below is decoded from hex and copied into
// executable pages at class-load time; MatrixMulFma.InitKernel casts `kernel`
// to delegate*<bool, double*, double*, double*, size_t, size_t, void>
// (add, C, A, B, A_cols, C_stride) before calling it.
unsafe static class FmaKernel6x8Linux64
{
	// Hex dump of the kernel's machine code, 32 bytes per line
	// (trailing 0x90 NOPs pad the final line).
	static string[] bytes = {
		"53 4D 89 C3 49 83 E3 FC 0F 84 66 02 00 00 49 FF CB 49 C1 EB 02 49 FF C3 C5 F9 57 C0 45 31 D2 48",
		"89 D0 48 89 CB C5 F1 57 C9 C5 E9 57 D2 C5 E1 57 DB C5 D9 57 E4 C5 D1 57 ED C5 C9 57 F6 C4 41 31",
		"57 C9 C4 41 29 57 D2 C4 41 21 57 DB C5 C1 57 FF C4 41 39 57 C0 0F 18 88 00 03 00 00 0F 18 8B 00",
		"03 00 00 C5 7D 10 23 C5 7D 10 6B 20 C4 62 7D 19 30 C4 42 8D B8 DC C4 42 95 B8 D6 C4 62 7D 19 70",
		"08 C4 42 8D B8 CC C4 C2 95 B8 F6 C4 62 7D 19 70 10 C4 C2 8D B8 EC C4 C2 95 B8 E6 C4 62 7D 19 70",
		"18 C4 C2 8D B8 DC C4 C2 95 B8 D6 C4 62 7D 19 70 20 C4 C2 8D B8 CC C4 C2 95 B8 C6 C4 62 7D 19 70",
		"28 C4 C2 8D B8 FC C4 42 8D B8 C5 0F 18 88 40 03 00 00 0F 18 8B 40 03 00 00 C5 7D 10 63 40 C5 7D",
		"10 6B 60 C4 62 7D 19 70 30 C4 42 8D B8 DC C4 42 95 B8 D6 C4 62 7D 19 70 38 C4 42 8D B8 CC C4 C2",
		"95 B8 F6 C4 62 7D 19 70 40 C4 C2 8D B8 EC C4 C2 95 B8 E6 C4 62 7D 19 70 48 C4 C2 8D B8 DC C4 C2",
		"95 B8 D6 C4 62 7D 19 70 50 C4 C2 8D B8 CC C4 C2 95 B8 C6 C4 62 7D 19 70 58 C4 C2 8D B8 FC C4 42",
		"8D B8 C5 0F 18 88 80 03 00 00 0F 18 8B 80 03 00 00 C5 7D 10 A3 80 00 00 00 C5 7D 10 AB A0 00 00",
		"00 C4 62 7D 19 70 60 C4 42 8D B8 DC C4 42 95 B8 D6 C4 62 7D 19 70 68 C4 42 8D B8 CC C4 C2 95 B8",
		"F6 C4 62 7D 19 70 70 C4 C2 8D B8 EC C4 C2 95 B8 E6 C4 62 7D 19 70 78 C4 C2 8D B8 DC C4 C2 95 B8",
		"D6 C4 62 7D 19 B0 80 00 00 00 C4 C2 8D B8 CC C4 C2 95 B8 C6 C4 62 7D 19 B0 88 00 00 00 C4 C2 8D",
		"B8 FC C4 42 8D B8 C5 0F 18 8B C0 03 00 00 C5 7D 10 A3 C0 00 00 00 C5 7D 10 AB E0 00 00 00 C4 62",
		"7D 19 B0 90 00 00 00 C4 42 8D B8 DC C4 42 95 B8 D6 C4 62 7D 19 B0 98 00 00 00 C4 42 8D B8 CC C4",
		"C2 95 B8 F6 C4 62 7D 19 B0 A0 00 00 00 C4 C2 8D B8 EC C4 C2 95 B8 E6 C4 62 7D 19 B0 A8 00 00 00",
		"C4 C2 8D B8 DC C4 C2 95 B8 D6 C4 62 7D 19 B0 B0 00 00 00 C4 C2 8D B8 CC C4 C2 95 B8 C6 C4 62 7D",
		"19 B0 B8 00 00 00 C4 C2 8D B8 FC C4 42 8D B8 C5 49 83 C2 04 48 81 C3 00 01 00 00 48 05 C0 00 00",
		"00 49 FF CB 0F 85 EB FD FF FF 4D 39 C2 72 45 E9 D5 00 00 00 45 31 D2 C5 F9 57 C0 C5 F1 57 C9 C5",
		"E9 57 D2 C5 E1 57 DB C5 D9 57 E4 C5 D1 57 ED C5 C9 57 F6 C4 41 31 57 C9 C4 41 29 57 D2 C4 41 21",
		"57 DB C5 C1 57 FF C4 41 39 57 C0 4D 39 C2 0F 83 95 00 00 00 4D 29 D0 4C 89 D0 48 C1 E0 06 48 01",
		"C8 48 83 C0 20 4B 8D 0C 52 48 C1 E1 04 48 01 D1 48 83 C1 28 C5 7D 10 60 E0 C5 7D 10 28 C4 62 7D",
		"19 71 D8 C4 42 8D B8 DC C4 42 95 B8 D6 C4 62 7D 19 71 E0 C4 42 8D B8 CC C4 C2 95 B8 F6 C4 62 7D",
		"19 71 E8 C4 C2 8D B8 EC C4 C2 95 B8 E6 C4 62 7D 19 71 F0 C4 C2 8D B8 DC C4 C2 95 B8 D6 C4 62 7D",
		"19 71 F8 C4 C2 8D B8 CC C4 C2 95 B8 C6 C4 62 7D 19 31 C4 C2 8D B8 FC C4 42 8D B8 C5 48 83 C0 40",
		"48 83 C1 30 49 FF C8 75 8B 4B 8D 1C 09 4B 8D 14 49 4A 8D 0C 8D 00 00 00 00 4B 8D 04 89 40 84 FF",
		"74 42 C5 25 58 1E C5 2D 58 56 20 C4 21 35 58 0C CE C4 A1 4D 58 74 CE 20 C5 D5 58 2C DE C5 DD 58",
		"64 DE 20 C5 E5 58 1C D6 C5 ED 58 54 D6 20 C5 F5 58 0C CE C5 FD 58 44 CE 20 C5 C5 58 3C C6 C5 3D",
		"58 44 C6 20 C5 7D 11 1E C5 7D 11 56 20 C4 21 7D 11 0C CE C4 A1 7D 11 74 CE 20 C5 FD 11 2C DE C5",
		"FD 11 64 DE 20 C5 FD 11 1C D6 C5 FD 11 54 D6 20 C5 FD 11 0C CE C5 FD 11 44 CE 20 C5 FD 11 3C C6",
		"C5 7D 11 44 C6 20 5B C5 F8 77 C3 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90 90"
	};

	// Entry point of the copied machine code; valid after the static
	// constructor has run.
	public static void* kernel;

	static FmaKernel6x8Linux64()
	{
		// Decode the hex dump and copy it into freshly allocated executable pages.
		// NOTE(review): pages appear to stay writable+executable for the process
		// lifetime and are never freed — fine for a singleton kernel, but worth
		// confirming against AllocExecPages' protection flags.
		byte[] codeBytes = ConvertBytes(bytes);
		kernel = AllocExecPages((size_t)codeBytes.Length);
		Marshal.Copy(codeBytes, 0, (IntPtr)kernel, codeBytes.Length);
	}
}
